{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1020eef2",
   "metadata": {},
   "source": [
    "### Overview of the save/load system"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fbaab7",
   "metadata": {},
   "source": [
    ">The following error is raised when saving; the cause has not been identified:\n",
    ">\n",
    "><font color=\"red\">RuntimeError: (NotFound) No Input(Filter) found for Conv operator.\n",
    "  [Hint: Expected ctx->HasInput(\"Filter\") == true, but received ctx->HasInput(\"Filter\"):0 != true:1.] (at C:\\home\\workspace\\Paddle_release\\paddle/fluid/operators/conv_op.cc:42)\n",
    "  [operator < conv2d > error]</font>\n",
    ">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42ce4c62",
   "metadata": {},
   "source": [
    "#### Save/load with the basic API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28c70649",
   "metadata": {},
   "source": [
    ">Training: use paddle.save/load\n",
    ">\n",
    ">Inference deployment: use paddle.jit.save/load (dynamic graph) or paddle.static.save_inference_model/load_inference_model (static graph)\n",
    ">"
   ]
  },
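  {
   "cell_type": "markdown",
   "id": "c4e1a0f1",
   "metadata": {},
   "source": [
    ">A minimal sketch of the inference-oriented path for a dynamic-graph Layer, using paddle.jit.save and paddle.jit.load. The layer, the output prefix jit_demo/linear and the input name x are illustrative assumptions, not part of the original example.\n",
    ">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e1a0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import paddle\n",
    "from paddle.static import InputSpec\n",
    "\n",
    "# Hypothetical output prefix, chosen only for this sketch\n",
    "prefix = os.path.join('jit_demo', 'linear')\n",
    "os.makedirs('jit_demo', exist_ok=True)\n",
    "\n",
    "net = paddle.nn.Linear(784, 10)\n",
    "# input_spec fixes the input signature; None keeps the batch size variable\n",
    "paddle.jit.save(net, prefix, input_spec=[InputSpec([None, 784], 'float32', 'x')])\n",
    "\n",
    "# Load the exported model and run it in inference mode\n",
    "loaded = paddle.jit.load(prefix)\n",
    "loaded.eval()\n",
    "out = loaded(paddle.randn([2, 784]))\n",
    "print(out.shape)"
   ]
  },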
  {
   "cell_type": "markdown",
   "id": "9b3218b3",
   "metadata": {},
   "source": [
    "#### Save/load with the high-level API"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19e3d604",
   "metadata": {},
   "source": [
    ">paddle.Model.fit (the training interface, which can also save parameters)\n",
    ">\n",
    ">paddle.Model.save\n",
    ">\n",
    ">paddle.Model.load"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70de694d",
   "metadata": {},
   "source": [
    "### Saving and loading models and parameters during training and tuning"
   ]
  },
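  {
   "cell_type": "markdown",
   "id": "9d238a15",
   "metadata": {},
   "source": [
    ">During training, what mainly gets saved is the model's parameters. When training takes a long time, saving them guards against losing progress to an unexpected interruption\n",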
    ">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "177e6196",
   "metadata": {},
   "source": [
    "#### Saving and loading parameters in dynamic graph mode"
   ]
  },
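  {
   "cell_type": "markdown",
   "id": "166e21e3",
   "metadata": {},
   "source": [
    ">Use paddle.save/load together with the state_dict of a Layer and of an Optimizer. The state_dict carries the object's persistable parameters: each key is a parameter name and each value is the parameter's actual value"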
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8d5e7ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import paddle.optimizer as opt\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3a6b0102",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 16\n",
    "BATCH_NUM = 4\n",
    "EPOCH_NUM = 4\n",
    "\n",
    "IMAGE_SIZE = 784\n",
    "CLASS_NUM = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1cffc7d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a dataset of random samples\n",
    "class RandomDateset(paddle.io.Dataset):\n",
    "    def __init__(self,num_samples):\n",
    "        self.num_samples = num_samples\n",
    "        \n",
    "    def __getitem__(self,idx):\n",
    "        image = np.random.random([IMAGE_SIZE]).astype('float32')\n",
    "        # np.random.randint excludes the upper bound, so use CLASS_NUM to cover every class\n",
    "        label = np.random.randint(0,CLASS_NUM,(1,)).astype('int64')\n",
    "        return image,label\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.num_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b8278a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network by subclassing nn.Layer\n",
    "class LinearNet(nn.Layer):\n",
    "    def __init__(self):\n",
    "        super(LinearNet,self).__init__()\n",
    "        self._linear = nn.Linear(IMAGE_SIZE,CLASS_NUM)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        return self._linear(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb127345",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the training function\n",
    "def train(layer,loader,loss_fn,opt):\n",
    "    for epoch_id in range(EPOCH_NUM):\n",
    "        for batch_id,data in enumerate(loader()):\n",
    "            x = data[0]\n",
    "            y = data[1]\n",
    "            predicts = layer(x)\n",
    "            loss = loss_fn(predicts,y)\n",
    "            acc = paddle.metric.accuracy(predicts,y)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            opt.clear_grad()\n",
    "            print(f\"epoch_id:{epoch_id},batch_id:{batch_id},loss:{loss.numpy()},acc:{acc.numpy()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f8c7aa42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the network, loss function and optimizer\n",
    "layer = LinearNet()\n",
    "loss_fn = paddle.nn.CrossEntropyLoss()\n",
    "adam = paddle.optimizer.Adam(learning_rate=0.01,parameters=layer.parameters())\n",
    "# Create the dataset and the data loader used for training\n",
    "dataset = RandomDateset(BATCH_NUM*BATCH_SIZE)\n",
    "loader = paddle.io.DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,drop_last=True,num_workers=2)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a5b3d15b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch_id:0,batch_id:0,loss:[2.5168338],acc:[0.0625]\n",
      "epoch_id:0,batch_id:1,loss:[4.1976337],acc:[0.125]\n",
      "epoch_id:0,batch_id:2,loss:[4.9551334],acc:[0.125]\n",
      "epoch_id:0,batch_id:3,loss:[3.608651],acc:[0.0625]\n",
      "epoch_id:1,batch_id:0,loss:[4.1102242],acc:[0.0625]\n",
      "epoch_id:1,batch_id:1,loss:[3.0134223],acc:[0.0625]\n",
      "epoch_id:1,batch_id:2,loss:[3.402636],acc:[0.125]\n",
      "epoch_id:1,batch_id:3,loss:[3.0109348],acc:[0.25]\n",
      "epoch_id:2,batch_id:0,loss:[3.4844694],acc:[0.]\n",
      "epoch_id:2,batch_id:1,loss:[5.342682],acc:[0.0625]\n",
      "epoch_id:2,batch_id:2,loss:[2.6492777],acc:[0.125]\n",
      "epoch_id:2,batch_id:3,loss:[2.6705422],acc:[0.0625]\n",
      "epoch_id:3,batch_id:0,loss:[2.9829288],acc:[0.1875]\n",
      "epoch_id:3,batch_id:1,loss:[2.3503506],acc:[0.25]\n",
      "epoch_id:3,batch_id:2,loss:[2.9422102],acc:[0.125]\n",
      "epoch_id:3,batch_id:3,loss:[2.7587576],acc:[0.125]\n"
     ]
    }
   ],
   "source": [
    "train(layer,loader,loss_fn,adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "33f05c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the layer parameters and the optimizer state\n",
    "paddle.save(layer.state_dict(),\"linear_net.pdparams\")\n",
    "paddle.save(adam.state_dict(),\"adam.pdopt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0c06fe98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved state dicts\n",
    "layer_state_dict = paddle.load(\"linear_net.pdparams\")\n",
    "opt_state_dict = paddle.load(\"adam.pdopt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e92f8da6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the loaded state to the layer and the optimizer\n",
    "layer.set_state_dict(layer_state_dict)\n",
    "adam.set_state_dict(opt_state_dict)"
   ]
  },
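  {
   "cell_type": "markdown",
   "id": "d7b2e4a1",
   "metadata": {},
   "source": [
    ">One way to check that a save/load round trip preserves the parameters is to load them into a freshly constructed layer and compare the values. The sketch below uses a standalone paddle.nn.Linear and a hypothetical file name, roundtrip_demo.pdparams; neither comes from the example above.\n",
    ">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b2e4a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import paddle\n",
    "\n",
    "# Source layer whose parameters are written to disk\n",
    "src = paddle.nn.Linear(784, 10)\n",
    "paddle.save(src.state_dict(), 'roundtrip_demo.pdparams')\n",
    "\n",
    "# A new layer starts from different random weights until the state is applied\n",
    "dst = paddle.nn.Linear(784, 10)\n",
    "dst.set_state_dict(paddle.load('roundtrip_demo.pdparams'))\n",
    "\n",
    "# Every parameter should now match the source layer\n",
    "dst_state = dst.state_dict()\n",
    "for name, value in src.state_dict().items():\n",
    "    assert np.allclose(value.numpy(), dst_state[name].numpy()), name"
   ]
  },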
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5adb66c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0cced5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e59eaa03",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa8a41e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f608077b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e1b93ccf",
   "metadata": {},
   "source": [
    "### Case study: interrupting and resuming training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7fc22dd",
   "metadata": {},
   "source": [
    ">Learn how Paddle saves and loads models, using a handwritten digit recognition model\n",
    ">\n",
    ">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba50cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.nn.functional as F\n",
    "from paddle.nn import Layer\n",
    "from paddle.vision.datasets import MNIST\n",
    "from paddle.metric import Accuracy\n",
    "from paddle.nn import Conv2D,MaxPool2D,Linear\n",
    "from paddle.static import InputSpec\n",
    "from paddle.vision.transforms import ToTensor\n",
    "\n",
    "paddle.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a044aadb",
   "metadata": {},
   "source": [
    "#### Load and preprocess the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08bf17dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = MNIST(mode=\"train\",transform=ToTensor())\n",
    "test_dataset = MNIST(mode=\"test\",transform=ToTensor())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b49adba",
   "metadata": {},
   "source": [
    "#### Build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46bba20a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(Layer):\n",
    "    def __init__(self):\n",
    "        super(MyModel,self).__init__()\n",
    "        self.conv1 = Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)\n",
    "        self.max_pool1 = MaxPool2D(kernel_size=2,stride=2)\n",
    "        self.conv2 = Conv2D(in_channels=6,out_channels=16,kernel_size=5,stride=1)\n",
    "        self.max_pool2 = MaxPool2D(kernel_size=2,stride=2)\n",
    "        self.linear1 = Linear(in_features=16*5*5,out_features=120)\n",
    "        self.linear2 = Linear(in_features=120,out_features=84)\n",
    "        self.linear3 = Linear(in_features=84,out_features=10)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.max_pool1(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.max_pool2(x)\n",
    "        x = paddle.flatten(x,start_axis=1,stop_axis=-1)\n",
    "        x = self.linear1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.linear3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11568c99",
   "metadata": {},
   "source": [
    "#### Configure and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01388842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST with ToTensor yields images of shape [1,28,28] and int64 labels of shape [1]\n",
    "inputs = InputSpec([None,1,28,28],'float32','inputs')\n",
    "labels = InputSpec([None,1],'int64','labels')\n",
    "model = paddle.Model(MyModel(),inputs,labels)\n",
    "\n",
    "optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters())\n",
    "model.prepare(\n",
    "    optim,\n",
    "    paddle.nn.CrossEntropyLoss(),\n",
    "    Accuracy()\n",
    ")\n",
    "\n",
    "model.fit(\n",
    "    train_dataset,\n",
    "    test_dataset,\n",
    "    epochs=3,\n",
    "    batch_size=64,\n",
    "    save_dir='mnist_checkpoint',\n",
    "    verbose=1\n",
    ")"
   ]
  },
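  {
   "cell_type": "markdown",
   "id": "d7223fc4",
   "metadata": {},
   "source": [
    "#### Suppose the training above was interrupted during the second epoch: load the parameters saved after the first epoch and continue training"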
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea47ad71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST with ToTensor yields images of shape [1,28,28] and int64 labels of shape [1]\n",
    "inputs = InputSpec([None,1,28,28],'float32','x')\n",
    "labels = InputSpec([None,1],'int64','label')\n",
    "model = paddle.Model(MyModel(),inputs,labels)\n",
    "\n",
    "optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters())\n",
    "\n",
    "# Call prepare before load so that the saved optimizer state can be restored as well\n",
    "model.prepare(\n",
    "    optim,\n",
    "    paddle.nn.CrossEntropyLoss(),\n",
    "    Accuracy()\n",
    ")\n",
    "# The path passed to load is the common prefix of the parameter and optimizer files\n",
    "model.load(path=\"./mnist_checkpoint/0\")\n",
    "\n",
    "model.fit(\n",
    "    train_dataset,\n",
    "    test_dataset,\n",
    "    # Set epochs to 2 because one epoch has already been completed\n",
    "    epochs=2,\n",
    "    batch_size=64,\n",
    "    save_dir='mnist_checkpoint',\n",
    "    verbose=1\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8e43319",
   "metadata": {},
   "source": [
    "The loss and acc after resuming are close to the values seen before the interruption"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "391a590a",
   "metadata": {},
   "source": [
    "#### Save the model parameters, the optimizer state, and the files needed for inference deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11b2cb89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the training parameters\n",
    "# model.save(\"mnist_model/test\")\n",
    "\n",
    "# Save the parameters and files needed for inference deployment\n",
    "model.save(\"mnist_model/test\",training=False)"
   ]
  },
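  {
   "cell_type": "markdown",
   "id": "e9f3b6c1",
   "metadata": {},
   "source": [
    ">Assuming the export above succeeded in dynamic graph mode and the network was exported with image-shaped input ([N,1,28,28]), the files under mnist_model/ can usually be loaded back with paddle.jit.load for a quick prediction check. This is a sketch, not part of the original walkthrough; the random tensor only stands in for a real image.\n",
    ">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f3b6c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "\n",
    "# Same prefix as the one passed to model.save(..., training=False)\n",
    "infer_net = paddle.jit.load('mnist_model/test')\n",
    "infer_net.eval()\n",
    "\n",
    "# Random tensor standing in for one 28x28 grayscale image (assumed input layout)\n",
    "fake_image = paddle.randn([1, 1, 28, 28])\n",
    "logits = infer_net(fake_image)\n",
    "print(paddle.argmax(logits, axis=-1).numpy())"
   ]
  },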
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c2b71b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1ed505a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2112b1d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
